import cv2
import numpy as np
from PIL import Image
import tensorflow as tf
from model import MODEL  # 导入网络模型结构
from NMS import predict  # 导入预测框处理方法
# Use the GPU(s) if present, but let TensorFlow grow GPU memory on demand
# instead of reserving the whole device up front.
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for device in gpus:
    tf.config.experimental.set_memory_growth(device, True)

# -------------------------------------------- #
# File locations: YOLOv4 weights, input video, input picture
# -------------------------------------------- #
yolo_weights = 'yolo4_voc_weights.h5'
video_path = 'D:/deeplearning/video/car.mp4'
picture_path = 'D:/deeplearning/database/picture/moto.jpg'

# -------------------------------------------- #
# Which inputs to run detection on (each flag is checked independently)
# -------------------------------------------- #
video = True
picture = False

# Open the video capture only when video detection is enabled. The video branch
# is the only place that calls cap.release(), so a capture opened in
# picture-only mode would otherwise never be released.
cap = cv2.VideoCapture(video_path) if video else None

# -------------------------------------------- #
# Detection configuration
#   class_names : the 20 VOC dataset class labels, indexed by predicted class id
#   anchors     : prior (anchor) box sizes, one pair per anchor.
#                 NOTE(review): the original comments disagree on whether each
#                 pair is (w, h) or (h, w); confirm against NMS.predict
#   num_anchors : number of anchors generated per grid cell
#   num_classes : total number of classes
#   input_shape : network input size, [height, width, channels]
#   conf_thresh : boxes whose class probability is below this are discarded
#   nms_thresh  : IoU threshold for non-maximum suppression; boxes overlapping
#                 less than this are kept, duplicates above it are removed
# -------------------------------------------- #
class_names = (
    'aeroplane bicycle bird boat bottle bus car cat chair cow '
    'diningtable dog horse motorbike person pottedplant sheep sofa train tvmonitor'
).split()

anchors = np.array([
    (12, 16), (19, 36), (40, 28),
    (36, 75), (76, 55), (72, 146),
    (142, 110), (192, 243), (459, 401),
])
num_anchors = 3
num_classes = len(class_names)
input_shape = [416, 416, 3]
conf_thresh = 0.6
nms_thresh = 0.4

# ----------------------------------------------------- #
# Build the network and load the pretrained weights
# (the weights file used here was trained on VOC)
# ----------------------------------------------------- #
yolo_model = MODEL(
    input_shape,
    num_anchors,
    num_classes,
    summary=False,
)
yolo_model.load_weights(yolo_weights)

# ----------------------------------------------------- #
# Image preprocessing
# ----------------------------------------------------- #
def preprocessing(image, inputs_shape):
    """Convert an OpenCV BGR image into a batched, normalized network input.

    Args:
        image: image array in OpenCV's BGR channel order.
        inputs_shape: target shape indexed as [height, width, ...]; only the
            first two entries are used.

    Returns:
        float32 array of shape (1, height, width, 3), RGB, values in [0, 1].
    """
    target_h, target_w = inputs_shape[0], inputs_shape[1]

    # OpenCV images are BGR; convert to RGB
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # PIL's resize takes (width, height)
    resized = Image.fromarray(np.uint8(rgb)).resize(size=(target_w, target_h))

    scaled = np.array(resized, dtype=np.float32) / 255.0
    # Prepend the batch axis: [h, w, 3] -> [1, h, w, 3]
    return scaled[np.newaxis, ...]

# ----------------------------------------------------- #
# Video mode: run detection on every frame and display the annotated frame
# ----------------------------------------------------- #
if video:
    # Fail early with a clear message instead of silently reading no frames
    if not cap.isOpened():
        raise OSError('Cannot open video file: ' + video_path)

    try:
        while True:
            # success is False (and img is None) once the stream is exhausted
            # or a frame cannot be read; stop instead of crashing on None.
            success, img = cap.read()
            if not success:
                break
            # Keep an untouched copy of the frame to draw the results on
            frame = img.copy()
            # Original frame size (h, w), passed to predict() together with the
            # model outputs
            image_shape = img.shape[0:2]
            # Convert to the [1, h, w, 3] normalized batch the network expects
            img = preprocessing(img, input_shape)

            # The model returns three output feature maps
            p5_output, p4_output, p3_output = yolo_model.predict(img)
            feats = [p5_output, p4_output, p3_output]

            # ----------------------------------------------------- #
            # Decode the raw predictions, apply the anchors, drop boxes below
            # conf_thresh and run non-maximum suppression with nms_thresh,
            # keeping at most max_boxes boxes.
            # Returns per-box coordinates, class scores and class indices.
            # ----------------------------------------------------- #
            predict_boxes, predict_score, predict_classes = predict(
                feats, image_shape, num_classes, anchors,
                conf_thresh, nms_thresh, max_boxes=100)
            print(predict_boxes, predict_score, predict_classes)

            # ----------------------------------------------------- #
            # Draw the predicted boxes
            # ----------------------------------------------------- #
            for box, score, class_index in zip(predict_boxes, predict_score, predict_classes):
                # Each box holds corner coordinates ordered (y1, x1, y2, x2).
                # OpenCV drawing functions need integer pixel coordinates.
                box_y1, box_x1 = int(box[0]), int(box[1])
                box_y2, box_x2 = int(box[2]), int(box[3])
                class_name = class_names[class_index]
                # Label text, e.g. "car: 0.87"
                class_name_score = class_name + ': ' + str(round(score.numpy(), 2))

                print('=====================================')
                print(box_x1, box_y1, box_x2, box_y2)
                print(class_name)
                print(class_name_score)

                cv2.rectangle(frame, (box_x1, box_y1), (box_x2, box_y2), color=(0, 255, 0), thickness=2)
                # Put the label just above the box's top-left corner
                cv2.putText(frame, class_name_score, (box_x1, box_y1 - 5), cv2.FONT_HERSHEY_COMPLEX, 1, (0, 0, 255), 2)

            # ----------------------------------------------------- #
            # Show the annotated frame
            # ----------------------------------------------------- #
            cv2.imshow('frame', frame)
            # Wait up to 10 ms for a key press; ESC (key code 27) quits
            if cv2.waitKey(10) & 0xFF == 27:
                break
    finally:
        # Release the video and close windows even if detection raised
        cap.release()
        cv2.destroyAllWindows()


# ----------------------------------------------------- #
# Picture mode: run detection on a single image and display the result
# ----------------------------------------------------- #
if picture:
    img = cv2.imread(picture_path)
    # cv2.imread does not raise on a missing/unreadable file; it returns None
    if img is None:
        raise FileNotFoundError('Cannot read image file: ' + picture_path)
    # Resize to a fixed 1280x720 size (the aspect ratio is not preserved)
    img = cv2.resize(img, (1280, 720))
    # Keep an untouched copy of the image to draw the results on
    frame = img.copy()
    # Image size (h, w), passed to predict() together with the model outputs
    image_shape = img.shape[0:2]
    # Convert to the [1, h, w, 3] normalized batch the network expects
    img = preprocessing(img, input_shape)

    # The model returns three output feature maps
    p5_output, p4_output, p3_output = yolo_model.predict(img)
    feats = [p5_output, p4_output, p3_output]

    # ----------------------------------------------------- #
    # Decode the raw predictions, apply the anchors, filter by conf_thresh and
    # run non-maximum suppression, keeping at most max_boxes boxes
    # ----------------------------------------------------- #
    predict_boxes, predict_score, predict_classes = predict(
        feats, image_shape, num_classes, anchors,
        conf_thresh, nms_thresh, max_boxes=20)

    # ----------------------------------------------------- #
    # Draw the predicted boxes
    # ----------------------------------------------------- #
    for box, score, class_index in zip(predict_boxes, predict_score, predict_classes):
        # Each box holds corner coordinates ordered (y1, x1, y2, x2).
        # OpenCV drawing functions need integer pixel coordinates.
        box_y1, box_x1 = int(box[0]), int(box[1])
        box_y2, box_x2 = int(box[2]), int(box[3])
        class_name = class_names[class_index]
        # Label text, e.g. "motorbike: 0.91"
        class_name_score = class_name + ': ' + str(round(score.numpy(), 2))

        cv2.rectangle(frame, (box_x1, box_y1), (box_x2, box_y2), color=(0, 255, 0), thickness=2)
        # Put the label just above the box's top-left corner
        cv2.putText(frame, class_name_score, (box_x1, box_y1 - 10), cv2.FONT_HERSHEY_COMPLEX, 1, (0, 0, 255), 2)

    # ----------------------------------------------------- #
    # Show the result
    # ----------------------------------------------------- #
    cv2.imshow('frame', frame)
    # waitKey(0) blocks until any key is pressed, then the window is closed
    cv2.waitKey(0)
    cv2.destroyAllWindows()

